import torch
from torch import nn
from torch.nn.functional import normalize, conv_transpose2d


def circulant(tensor, dim, height, stride=1):
    """Build cyclically shifted copies of ``tensor`` along dimension ``dim``.

    The tensor is reversed along ``dim``, extended with its own first
    ``height`` reversed entries, cut into sliding windows of the original
    length (step ``stride``), and each window is reversed back. The window
    contents end up in a new trailing dimension, while ``dim`` indexes the
    successive shifts.

    E.g. tensor=[0,1,2], dim=0, height=2 --> [[0,1,2],[2,0,1],[1,2,0]]
    """
    period = tensor.shape[dim]
    reversed_tensor = tensor.flip((dim,))
    # Wrap-around part: the leading `height` entries of the reversed tensor.
    wrap = reversed_tensor.narrow(dim, 0, height)
    extended = torch.cat((reversed_tensor, wrap), dim=dim)
    windows = extended.unfold(dim, period, stride)
    return windows.flip((-1,))


def kernel_toeplitz(kernel, data_shape, device, stride=1):
    '''
    Build a matrix whose rows are shifted copies of the zero-padded, flattened kernel.

    Each kernel is padded with zeros on the right up to the data width,
    flattened, padded again up to ``v_h * v_w`` entries per channel and
    reshaped to one vector of length ``v_w * v_h * in_channel``. Shifted
    copies of that vector are produced with ``circulant`` and only the rows
    flagged by a 0/1 mask derived from ``stride`` are returned.

    :param kernel: 4-D tensor unpacked as (k_batch, _, k_h, k_w). Its second
        axis must equal ``in_channel`` from ``data_shape``, otherwise the
        concatenation with the zero padding fails.
    :param data_shape: 4-tuple (v_batch, in_channel, v_h, v_w)
    :param device: device on which the padding tensors and the mask are placed
    :param stride: step used when building the row-selection mask
    :return: tensor of shape (k_batch, n_selected_rows, v_w * v_h * in_channel)
    :raises ValueError: if ``v_h - k_h < stride`` or ``v_w - k_w < stride``
        while ``v_h != k_h``
    '''
    (v_batch, in_channel, v_h, v_w) = data_shape
    (k_batch, _, k_h, k_w) = kernel.shape
    if ((v_h - k_h) < stride or (v_w - k_w) < stride) and v_h != k_h:  # v_h == k_h is exempt from the check (author's note: that case is linear)
        raise ValueError("kernel < stride")
    # Number of kernel positions along the width and the height.
    horizontal_move = int((v_w - k_w) / stride + 1)
    vertical_move = int((v_h - k_h) / stride + 1)
    # moving_steps =  int(horizontal_move * vertical_move)    # eliminate the init state
    # Pad every kernel row with zeros up to the data width, then flatten the
    # two spatial axes into one.
    vw_tensor = torch.zeros((k_batch, in_channel, k_h, (v_w - k_w))).to(device)
    k_vector = torch.cat((kernel, vw_tensor), dim=3).flatten(start_dim=2)
    # Remaining zeros needed so each channel spans v_h * v_w entries.
    moving_space = (v_h * v_w) - k_vector.shape[-1]
    k_vector = torch.cat((k_vector, torch.zeros(k_batch, in_channel, moving_space, device=device)),
                         dim=2).contiguous().view(k_batch, (v_w * v_h * in_channel))
    # 0/1 mask over shift offsets. One repetition per vertical position:
    # `horizontal_move` ones spaced `stride` apart, followed by
    # k_w - 1 + v_w * (stride - 1) zeros.
    bool_index = torch.tensor((([1] + [0] * (stride - 1)) * (horizontal_move - 1) + [1] + (
            [0] * (k_w - 1 + v_w * (stride - 1)))) * vertical_move).to(device)  # TODO: check this once needed
    # res = circulant(k_vector,dim=1, height= v_w * v_h - 1, stride=1)
    # Only generate shifts up to the last offset that the mask selects.
    res = circulant(k_vector, dim=1, height=torch.nonzero(bool_index)[-1].item(), stride=1)
    # Keep only the shifts flagged with 1 in the mask.
    return res[:, torch.where(bool_index == 1)[0], :]


def pooling_toeplitz(data, kernel_size, stride, device):
    '''
    Build the matrix form of an averaging window applied to ``data``.

    Only square inputs with square kernels are handled, and both the input
    width and the kernel width must be divisible by ``stride``; in every
    other case ``None`` is returned.

    :param data: 4-D tensor; only its shape (batch, in_channel, v_h, v_w) is read
    :param kernel_size: pair (k_h, k_w)
    :param stride: step between successive shifted rows
    :param device: device on which the averaging pattern is placed
    :return: 2-D tensor holding the selected shifted rows, or None when the
        shape constraints above are not met
    '''
    _, in_channel, v_h, v_w = data.shape
    k_h, k_w = kernel_size
    if not (v_h == v_w and k_h == k_w and v_w % stride == 0 and k_w % stride == 0):
        return None

    flat_len = v_h * v_w * in_channel
    n_rows = flat_len // stride

    # k_h rows of k_w ones, each padded with zeros up to the data width.
    window_rows = torch.cat([torch.ones((k_h, k_w)), torch.zeros(k_h, v_w - k_w)], dim=1).to(device=device)
    tail = torch.zeros(flat_len - v_w * k_h).to(device=device)
    # Flattened window followed by zeros, scaled so the ones average the window.
    pattern = torch.cat([window_rows.flatten(start_dim=0), tail]) * (
            1 / (k_h * k_w))

    shifted = circulant(pattern, dim=0, height=(n_rows - 1) * stride, stride=stride)

    # Alternate runs of kept and dropped rows, each run v_h // stride long.
    run = v_h // stride
    repeats = n_rows // (2 * v_h // stride)
    keep_mask = torch.tensor(([1] * run + [0] * run) * repeats)
    keep_rows = torch.where(keep_mask == 1)[0]
    return shifted[keep_rows, :]


def data_toeplitz(data, kernel, stride=1):
    '''
    Unfold the input into its sliding-window matrix for the given kernel size.

    Every row of the result is one flattened patch of ``data``, so that a
    convolution with a kernel flattened to a vector becomes a matrix product.

    :param data: 4-D tensor of shape (batch, in_channel, height, width)
    :param kernel: 4-D tensor of shape (out_channel, in_channel, k_h, k_w);
        only its shape is read
    :param stride: step of the sliding window
    :return: tensor of shape (batch, n_patches, in_channel * k_h * k_w)
    :raises ValueError: if the kernel and the data disagree on the number of
        input channels
    '''
    _, in_channel_k, k_h, k_w = kernel.shape
    _, in_channel, _, _ = data.shape
    if in_channel_k != in_channel:
        # A bare string cannot be raised in Python 3; use a real exception.
        raise ValueError("channel mismatch")
    patches = torch.nn.functional.unfold(data, (k_h, k_w), dilation=(1, 1), stride=stride)
    # unfold yields (batch, patch_size, n_patches); put patches on the rows.
    return patches.transpose(1, 2)


def model_require(model, layer):
    '''
    Return the average-pooling module tied to a layer index, if there is one.

    :param model: not used by this function
    :param layer: layer index; ``layer - 1`` equal to 2 or 1 yields a pooling module
    :return: an ``nn.AvgPool2d`` with a 2x2 window, or None for any other index
    '''
    previous = layer - 1
    if previous == 2:
        return nn.AvgPool2d(kernel_size=2)
    if previous == 1:
        return nn.AvgPool2d(kernel_size=(2, 2))
    return None


def power_iteration(weight, n_iter, eps, coeff, matrix=True, name=""):
    '''
    Estimate the largest singular value of a 2-D weight and rescale by it.

    The estimate ``sigma`` comes from ``n_iter`` rounds of power iteration
    started from random vectors. The scaling factor is ``max(1, sigma / coeff)``,
    so the weight is only ever shrunk, never enlarged.

    :param weight: 2-D tensor of shape (h, w)
    :param n_iter: number of power-iteration rounds
    :param eps: epsilon forwarded to ``normalize`` to avoid division by zero
    :param coeff: target upper bound for the largest singular value
    :param matrix: if True return the rescaled weight, otherwise the factor
    :param name: unused; kept for backward compatibility
    :return: ``weight / factor`` when ``matrix`` is True, else the factor as a
        tensor of shape (1,)
    '''
    h, w = weight.size()
    u = normalize(weight.new_empty(h).normal_(0, 1), dim=0, eps=eps)
    v = normalize(weight.new_empty(w).normal_(0, 1), dim=0, eps=eps)

    # The updates write into u and v through `out=`, which autograd does not
    # support; without no_grad this raises as soon as `weight` requires grad.
    with torch.no_grad():
        for _ in range(n_iter):
            v = normalize(torch.mv(weight.t(), u), dim=0, eps=eps, out=v)
            u = normalize(torch.mv(weight, v), dim=0, eps=eps, out=u)
    if n_iter > 0:
        # Copy the in-place-updated buffers before they enter the
        # differentiable computation of sigma below.
        u = u.clone()
        v = v.clone()

    # Computed outside no_grad so gradients can flow back to `weight`.
    sigma = torch.dot(u, torch.mv(weight, v))
    factor = torch.max(torch.ones(1, device=weight.device), sigma / coeff)
    if not matrix:
        return factor
    return weight / factor


def cnn_power_iterations(n_power_iterations, weight, stride, epsilon, input_dim, coeff, device):
    '''
    Rescale a convolution kernel by an estimate of its largest singular value.

    The convolution (no padding, given ``stride``) is treated as a linear map
    on inputs of shape ``input_dim``. Power iteration alternates between
    ``conv2d`` and ``conv_transpose2d`` to estimate ``sigma``; the kernel is
    then divided by ``max(1, sigma / coeff) + 1e-5``.

    :param n_power_iterations: number of power-iteration rounds
    :param weight: convolution kernel accepted by ``conv2d``
    :param stride: stride of the convolution
    :param epsilon: epsilon forwarded to ``normalize`` to avoid division by zero
    :param input_dim: 4-element shape (batch, channel, height, width) of the input
    :param coeff: target upper bound for the largest singular value
    :param device: device for the iteration vectors; should match ``weight``
    :return: the rescaled kernel, same shape as ``weight``
    '''
    # Random unit-norm start vector over the flattened input space.
    num_input_dim = input_dim[0] * input_dim[1] * input_dim[2] * input_dim[3]
    v = normalize(torch.randn(num_input_dim), dim=0, eps=epsilon).to(device=device)

    # One forward pass, only to learn the output shape.
    with torch.no_grad():
        out_shape = nn.functional.conv2d(v.view(input_dim), weight, stride=stride, padding=0,
                                         bias=None).shape
    num_output_dim = out_shape[0] * out_shape[1] * out_shape[2] * out_shape[3]
    u = normalize(torch.randn(num_output_dim), dim=0, eps=epsilon).to(device=device)

    # The updates write into u and v through `out=`, which autograd does not
    # support; without no_grad this raises as soon as `weight` requires grad.
    with torch.no_grad():
        for _ in range(n_power_iterations):
            v_s = nn.functional.conv_transpose2d(u.view(out_shape), weight, stride=stride,
                                                 padding=0, output_padding=0)
            v = normalize(v_s.view(-1), dim=0, eps=epsilon, out=v)

            u_s = nn.functional.conv2d(v.view(input_dim), weight, stride=stride, padding=0,
                                       bias=None)
            u = normalize(u_s.view(-1), dim=0, eps=epsilon, out=u)
    if n_power_iterations > 0:
        # Copy the in-place-updated buffers before they enter the
        # differentiable computation of sigma below.
        u = u.clone()
        v = v.clone()

    # Computed outside no_grad so gradients can flow back to `weight`.
    weight_v = nn.functional.conv2d(v.view(input_dim), weight, stride=stride, padding=0,
                                    bias=None).view(-1)
    sigma = torch.dot(u.view(-1), weight_v)
    # Enforce the spectral norm only as an upper bound: never scale up.
    factor_reverse = torch.max(torch.ones(1, device=weight.device), sigma / coeff)

    # The small constant keeps the division numerically stable.
    return weight / (factor_reverse + 1e-5)

def normalization(x, threshold=2, device="cuda"):
    '''
    Shrink each sample so that its L2 norm does not exceed ``threshold``.

    Every sample (first axis) is flattened and its L2 norm computed. Samples
    whose norm exceeds ``threshold`` are divided by ``norm / threshold``;
    the others are left unchanged.

    :param x: tensor whose first axis is the batch axis
    :param threshold: maximum allowed per-sample L2 norm; 0 disables the
        operation and returns ``x`` itself
    :param device: unused; kept for backward compatibility. The computation
        runs on the device of ``x``.
    :return: tensor with the same shape as ``x``
    '''
    if threshold == 0:
        return x
    shape = x.shape
    flat = x.reshape(shape[0], -1)
    norm = torch.linalg.vector_norm(flat, ord=2, dim=1)
    # Build the lower bound on the input's own device: using the `device`
    # argument (default "cuda") broke CPU inputs and mismatched devices.
    factor = torch.max(torch.ones(1, device=norm.device), norm / threshold)
    return (flat / factor.unsqueeze(1)).reshape(shape)